import ephere_moov as moov

import sys
from string import Template

# Scale applied to the user-supplied attract stiffness/magnitude before it is
# handed to the update code (see SetForces in both updater classes).
_attractMagnitudeMultiplier = 100

# Fallback values used by SetUpdateKernelParams when gravity/drag are not given.
# NOTE(review): 'GravityScale' is not read anywhere in this section of the module;
# confirm whether other code relies on it.
_defaultInputParams = {
	'Gravity': -981.0,
	'GravityScale': 1.0,
	'Drag': 0.01,
}

# sys.executable can be None (or empty) in embedded interpreters; guard against
# None so importing this module never fails on .lower().
_isRunning3dsMax = '3dsmax' in ( sys.executable or '' ).lower()
# Index of the velocity component gravity is added to: 2 inside 3ds Max, 1 elsewhere
# (presumably matching each host's up axis).
_gravityAxis = 2 if _isRunning3dsMax else 1

class OpenCLKernels:
	"""OpenCL C sources for the kernels compiled by SolverUpdater_OpenCL.

	Every kernel indexes its buffers with get_global_id( 0 ). Attributes ending in
	_GravityAxis are string.Template objects whose $gravityAxis placeholder selects the
	velocity component gravity is added to; the matching un-suffixed attributes are the
	same sources with the module-level _gravityAxis substituted in.
	"""

	# Drag (velocity scaling) plus gravity along $gravityAxis.
	# NOTE(review): applied to every particle, whereas SolverUpdater._Update skips drag and
	# gravity for particles with mass == 0 — confirm which behavior is intended.
	UpdateGravityDrag_GravityAxis = Template( '''
	kernel void update( global Moov_ParticleDescription* pd,				// 0
						float step,											// 1
						float gravity,										// 2
						float velocityMultiplier							// 3
						)
	{
		int i = get_global_id( 0 );
		pd[i].v[0] *= velocityMultiplier;
		pd[i].v[1] *= velocityMultiplier;
		pd[i].v[2] *= velocityMultiplier;
		pd[i].v[$gravityAxis] += step * gravity;
	}
	''' )

	UpdateGravityDrag = UpdateGravityDrag_GravityAxis.substitute( gravityAxis = _gravityAxis )

	# Same as UpdateGravityDrag, plus optional per-particle forces (args 4-5) and an
	# optional pull toward attractPositions proportional to the offset from the current
	# position (args 6-10). Args 0-10 are numbered in the comments inside the source.
	UpdateWithForces_GravityAxis = Template( '''
	kernel void update( /*  0 */ global Moov_ParticleDescription* pd,
						/*  1 */ float step,
						/*  2 */ float gravity,
						/*  3 */ float velocityMultiplier,
						/*  4 */ int hasParticleForces,
						/*  5 */ global Moov_Vector3* particleForces,
						/*  6 */ int hasAttractForces,
						/*  7 */ float attractStiffness,
						/*  8 */ global Moov_Vector3* attractPositions,
						/*  9 */ int useAttractMultipliers,
						/* 10 */ global float* attractMultipliers )
	{
		int i = get_global_id( 0 );
		pd[i].v[0] *= velocityMultiplier;
		pd[i].v[1] *= velocityMultiplier;
		pd[i].v[2] *= velocityMultiplier;
		pd[i].v[$gravityAxis] += step * gravity;
		if( hasParticleForces )
		{
			float mul = step; /* / pd[i].mass; */
			pd[i].v[0] += particleForces[i][0] * mul;
			pd[i].v[1] += particleForces[i][1] * mul;
			pd[i].v[2] += particleForces[i][2] * mul;
		}
		if( hasAttractForces )
		{
			float mul = attractStiffness * step; /* / pd[i].mass; */
			mul *= useAttractMultipliers == 0 ? 1.0f : attractMultipliers[i];
			pd[i].v[0] += ( attractPositions[i][0] - pd[i].x[0] ) * mul;
			pd[i].v[1] += ( attractPositions[i][1] - pd[i].x[1] ) * mul;
			pd[i].v[2] += ( attractPositions[i][2] - pd[i].x[2] ) * mul;
		}
	}
	''' )

	UpdateWithForces = UpdateWithForces_GravityAxis.substitute( gravityAxis = _gravityAxis )

	# Copies root positions (arg 1) into the particle descriptions.
	UpdateRoots = '''
	kernel void updateRoots( global Moov_ParticleDescription* pd,
						global Moov_Vector3* rootPositions )
	{
		int i = get_global_id( 0 );
		pd[i].x[0] = rootPositions[i][0];
		pd[i].x[1] = rootPositions[i][1];
		pd[i].x[2] = rootPositions[i][2];
	}
	'''

	# Copies root positions (arg 1) and orientations (arg 2) into the particle descriptions.
	UpdateRootsWithOrientations = '''
	kernel void updateRoots( global Moov_ParticleDescription* pd,
						global Moov_Vector3* rootPositions, global Moov_Quaternion* rootOrientations )
	{
		int i = get_global_id( 0 );
		pd[i].x[0] = rootPositions[i][0];
		pd[i].x[1] = rootPositions[i][1];
		pd[i].x[2] = rootPositions[i][2];
		pd[i].rotation[0] = rootOrientations[i][0];
		pd[i].rotation[1] = rootOrientations[i][1];
		pd[i].rotation[2] = rootOrientations[i][2];
		pd[i].rotation[3] = rootOrientations[i][3];
	}
	'''

	# Copies per-particle masses (arg 1) into the particle descriptions.
	UpdateMases = '''
	kernel void updateMasses( global Moov_ParticleDescription* pd, global float* masses )
	{
		int i = get_global_id( 0 );
		pd[i].mass = masses[i];
	}
	'''

	# Correctly spelled alias; the misspelled UpdateMases is kept for existing callers.
	UpdateMasses = UpdateMases

class SolverUpdater_OpenCL:
	"""Class containing solver Update functions with OpenCL enabled.

	All per-particle work runs in the kernels from OpenCLKernels, which are compiled in
	the constructor through solver.CreateKernel (so construction raises if compilation
	fails; GetSolverUpdater relies on that to fall back to SolverUpdater).
	"""
	def __init__( self, solver ):
		self.solver = solver
		# Two velocity-update kernels: gravity + drag only, and the same plus forces.
		# SetUpdateKernelParams / SetForces choose which one UpdateParticles runs.
		self.updateKernel_GravityDrag = solver.CreateKernel( OpenCLKernels.UpdateGravityDrag, True )
		self.updateKernel_WithForces = solver.CreateKernel( OpenCLKernels.UpdateWithForces, True )
		self.updateKernel = self.updateKernel_GravityDrag
		# Root kernels: positions only, or positions plus orientations.
		self.updateRoots_Positions = solver.CreateKernel( OpenCLKernels.UpdateRoots, True )
		self.updateRoots_WithOrientations = solver.CreateKernel( OpenCLKernels.UpdateRootsWithOrientations, True )
		self.updateRootsKernel = self.updateRoots_Positions
		self.updateMassesKernel = solver.CreateKernel( OpenCLKernels.UpdateMases, True )

	def Reset( self, dynamicParticleSet, rootParticleSet ):
		"""Nothing to prepare: unlike SolverUpdater.Reset, no id -> index maps are built,
		because the kernels index their buffers with get_global_id( 0 )."""
		pass

	def SetUpdateKernelParams( self, gravity = None, drag = None ):
		"""Set gravity and drag on both velocity kernels and select the gravity/drag-only kernel.

		:param gravity: gravity acceleration; defaults to _defaultInputParams['Gravity'].
		:param drag: fraction of velocity removed per update; the kernels receive 1 - drag.
		"""
		Gravity = gravity if gravity is not None else _defaultInputParams['Gravity']
		Drag = drag if drag is not None else _defaultInputParams['Drag']
		# Arg 0 is the particle-description buffer (type 'None' here); arg 1 (step) is
		# set on every UpdateParticles call.
		argTypes = ['None', 'float', 'float', 'float']
		argValues = [0, 0, Gravity, 1.0 - Drag]
		self.updateKernel_GravityDrag.SetArgs( argTypes, argValues )
		self.updateKernel_WithForces.SetArgs( argTypes, argValues )
		self.updateKernel = self.updateKernel_GravityDrag

	def SetForces( self, forces = None, attractStiffness = None, initialPositions = None, rampMultipliers = None  ):
		"""Bind force inputs (kernel args 4-10) and select the kernel that applies them.

		NOTE: the keyword names differ from SolverUpdater.SetForces (particleForces,
		attractMagnitude); positional calls work with both classes.

		:param forces: per-particle force vectors, or None for no external forces.
		:param attractStiffness: attract strength (scaled by _attractMagnitudeMultiplier);
			attraction is enabled only when both this and initialPositions are given.
		:param initialPositions: per-particle attract target positions.
		:param rampMultipliers: per-particle attract multipliers; ignored unless it has
			more than one element.
		"""
		self.updateKernel = self.updateKernel_WithForces

		# SetArg is untyped, so scalar values are converted explicitly to match the
		# kernel's declared int/float parameters (a bool or int must not reach a float arg).
		# One-element placeholder buffers are bound when an input is absent.
		hasParticleForces = forces is not None
		self.updateKernel.SetArg( 4, int( hasParticleForces ) )
		self.updateKernel.SetBufferArg( 5, forces if hasParticleForces else [ moov.Vector3( 0, 0, 0 ) ] )

		hasAttractForces = initialPositions is not None and attractStiffness is not None
		self.updateKernel.SetArg( 6, int( hasAttractForces ) )
		self.updateKernel.SetArg( 7, float( attractStiffness ) * _attractMagnitudeMultiplier if hasAttractForces else 0.0 )
		self.updateKernel.SetBufferArg( 8, initialPositions if hasAttractForces else [ moov.Vector3( 0, 0, 0 ) ] )

		useRampMultipliers = 1 if rampMultipliers is not None and len( rampMultipliers ) > 1 else 0
		self.updateKernel.SetArg( 9, useRampMultipliers )
		self.updateKernel.SetBufferArg( 10, rampMultipliers if useRampMultipliers else [0.0] )

	def UpdateRoots( self, rootParticleSet, rootPositions, rootOrientations = None ):
		"""Write rootPositions (and rootOrientations, when given) into the root particles."""
		if rootOrientations is None:
			self.updateRootsKernel = self.updateRoots_Positions
			updatedInformation = moov.ParticleInformation.Position
		else:
			self.updateRootsKernel = self.updateRoots_WithOrientations
			updatedInformation = moov.ParticleInformation( moov.ParticleInformation.Position | moov.ParticleInformation.Rotation )
			self.updateRootsKernel.SetBufferArg( 2, rootOrientations )
		self.updateRootsKernel.SetBufferArg( 1, rootPositions )
		self.solver.UpdateParticlesCL( rootParticleSet, self.updateRootsKernel, updatedInformation )

	def UpdateParticles( self, step, particleSet ):
		"""Run the currently selected velocity kernel over particleSet with time step *step*."""
		self.updateKernel.SetArg( 1, float( step ) )
		self.solver.UpdateParticlesCL( particleSet, self.updateKernel, moov.ParticleInformation( moov.ParticleInformation.Mass | moov.ParticleInformation.Position | moov.ParticleInformation.Velocity ) )

	def UpdateMasses( self, particleSet, masses ):
		"""Write per-particle *masses* into particleSet."""
		self.updateMassesKernel.SetBufferArg( 1, masses )
		self.solver.UpdateParticlesCL( particleSet, self.updateMassesKernel, moov.ParticleInformation.Mass )


class SolverUpdater:
	"""Class containing solver update functions that use the non-OpenCL solver interface.

	Updates go through solver.UpdateParticles with a Python callback that receives a
	particle description (presumably once per particle). The callback only knows the
	solver particle id, so Reset() must be called first to build the id -> index maps
	used to look up the per-particle input arrays, and SetUpdateKernelParams() must be
	called before UpdateParticles().
	"""
	def __init__( self, solver ):
		self.solver = solver

	@staticmethod
	def _Update( pd, step, gravity, velocityMultiplier, attractMagnitude, initialParticlePositions, particleForces, solverIdToParticleIndexMap, rampMultipliers ):
		"""Update the velocity of one dynamic particle description *pd* in place.

		Drag (velocityMultiplier) and gravity are applied only when pd.mass != 0. The
		attract pull toward initialParticlePositions and the external particleForces are
		applied when those arrays are given; both are indexed through
		solverIdToParticleIndexMap[pd.id].
		"""
		if pd.mass != 0:
			pd.v[0] *= velocityMultiplier
			pd.v[1] *= velocityMultiplier
			pd.v[2] *= velocityMultiplier
			pd.v[_gravityAxis] += step * gravity

		if initialParticlePositions is None and particleForces is None:
			# Nothing below needs the particle index; skip the per-particle map lookup.
			return
		particleIndex = solverIdToParticleIndexMap[pd.id]

		if initialParticlePositions is not None:
			# Velocity change proportional to the offset from the target position.
			# Not divided by pd.mass (disabled, as in the OpenCL kernel).
			mul = attractMagnitude * step
			if rampMultipliers is not None:
				mul *= rampMultipliers[particleIndex]
			target = initialParticlePositions[particleIndex]
			pd.v[0] += ( target[0] - pd.x[0] ) * mul
			pd.v[1] += ( target[1] - pd.x[1] ) * mul
			pd.v[2] += ( target[2] - pd.x[2] ) * mul

		if particleForces is not None:
			# Not divided by pd.mass (disabled, as in the OpenCL kernel).
			force = particleForces[particleIndex]
			pd.v[0] += force[0] * step
			pd.v[1] += force[1] * step
			pd.v[2] += force[2] * step

	@staticmethod
	def _UpdateRoots( pd, rootPositions, solverIdToParticleIndexMap, rootOrientations ):
		"""Copy the root position (and orientation, when given) for *pd* from the input arrays."""
		particleIndex = solverIdToParticleIndexMap[pd.id]
		pd.x = rootPositions[particleIndex]
		if rootOrientations is not None:
			pd.rotation = rootOrientations[particleIndex]

	@staticmethod
	def _UpdateMass( pd, masses, solverIdToParticleIndexMap ):
		"""Copy the mass for *pd* from *masses*."""
		pd.mass = masses[solverIdToParticleIndexMap[pd.id]]

	def CreateIdToParticleIndexMap( self, particleSet ):
		"""Return a dict mapping each solver particle id in *particleSet* to its position
		in the particle information returned by the solver."""
		particleData = self.solver.GetParticleInformation( particleSet, moov.ParticleInformation.Id )
		return { particleData[index].id: index for index in range( len( particleData ) ) }

	def Reset( self, dynamicParticleSet, rootParticleSet ):
		"""Rebuild the id -> index maps for the root and dynamic particle sets."""
		self.rootIdToIndexMap = self.CreateIdToParticleIndexMap( rootParticleSet )
		self.dynamicIdToIndexMap = self.CreateIdToParticleIndexMap( dynamicParticleSet )

	def SetUpdateKernelParams( self, gravity = None, drag = None):
		"""Store gravity and drag (defaults from _defaultInputParams) and clear all forces."""
		self.Gravity = gravity if gravity is not None else _defaultInputParams['Gravity']
		self.Drag = drag if drag is not None else _defaultInputParams['Drag']
		# Velocities are multiplied by this every update.
		self.DragComplement = 1 - self.Drag
		self.SetForces()

	def SetForces( self, particleForces = None, attractMagnitude = None, initialPositions = None, rampMultipliers = None  ):
		"""Store the force inputs used by UpdateParticles.

		:param particleForces: per-particle force vectors, or None.
		:param attractMagnitude: attract strength, scaled by _attractMagnitudeMultiplier;
			None is treated as 0.
		:param initialPositions: per-particle attract targets; attraction is applied only
			when this is given.
		:param rampMultipliers: per-particle attract multipliers; ignored unless it has
			more than one element.
		"""
		self.particleForces = particleForces
		self.attractMagnitude = attractMagnitude * _attractMagnitudeMultiplier if attractMagnitude is not None else 0
		self.initialPositions = initialPositions
		self.rampMultipliers = rampMultipliers if rampMultipliers is not None and len( rampMultipliers ) > 1 else None

	def UpdateRoots( self, rootParticleSet, rootPositions, rootOrientations = None ):
		"""Write rootPositions (and rootOrientations, when given) into the root particles."""
		self.solver.UpdateParticles( rootParticleSet,
			lambda pd: SolverUpdater._UpdateRoots( pd, rootPositions, self.rootIdToIndexMap, rootOrientations ),
			moov.ParticleInformation( moov.ParticleInformation.Id | moov.ParticleInformation.Position | moov.ParticleInformation.Rotation ) )

	def UpdateParticles( self, step, particleSet ):
		"""Apply drag, gravity and the configured forces to particleSet for time step *step*."""
		self.solver.UpdateParticles( particleSet,
			lambda pd: SolverUpdater._Update( pd, step, self.Gravity, self.DragComplement, self.attractMagnitude, self.initialPositions, self.particleForces, self.dynamicIdToIndexMap, self.rampMultipliers ),
			moov.ParticleInformation( moov.ParticleInformation.Id | moov.ParticleInformation.Mass | moov.ParticleInformation.Position | moov.ParticleInformation.Velocity ) )

	def UpdateMasses( self, particleSet, masses ):
		"""Write per-particle *masses* into particleSet.

		Indices come from the dynamic-particle id map built by Reset, so particleSet is
		assumed to be the dynamic set — TODO confirm with callers.
		"""
		self.solver.UpdateParticles( particleSet,
			lambda pd: SolverUpdater._UpdateMass( pd, masses, self.dynamicIdToIndexMap ),
			moov.ParticleInformation( moov.ParticleInformation.Id | moov.ParticleInformation.Mass ) )


def GetSolverUpdater( solver, tryOpenCLFirst = True ):
	"""Return the updater implementation to drive *solver* with.

	:param solver: moov solver; solver.GetParameters().hasOpenCL tells whether the
		OpenCL path is available.
	:param tryOpenCLFirst: when True (default) and the solver reports OpenCL support,
		try SolverUpdater_OpenCL before falling back.
	:return: a SolverUpdater_OpenCL if it could be constructed, otherwise a SolverUpdater.
	"""
	params = solver.GetParameters()

	if tryOpenCLFirst and params.hasOpenCL:
		# Constructing SolverUpdater_OpenCL compiles its kernels; if that fails on the
		# selected platform, fall back to the non-OpenCL updater. Only ordinary errors are
		# caught so KeyboardInterrupt/SystemExit still propagate.
		try:
			return SolverUpdater_OpenCL( solver )
		except Exception as error:
			import warnings
			warnings.warn( 'OpenCL solver updater unavailable, falling back to the non-OpenCL updater: %s' % ( error, ), RuntimeWarning )

	return SolverUpdater( solver )
	
